import torch
import cv2
import numpy as np
import torchvision
from torchvision.utils import save_image

from modules.tune.lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
from modules.tune.lightglue.utils import load_image, rbd


def init_extractor(extractor_type='superpoint', max_num_keypoints=64):
    """Build a keypoint extractor and a matching LightGlue matcher.

    Both models are switched to eval mode and moved to CUDA.

    Args:
        extractor_type: one of 'superpoint', 'disk', 'aliked', 'sift'.
            The same name selects the LightGlue ``features`` configuration,
            so extractor and matcher always agree.
        max_num_keypoints: keypoint cap passed to the extractor's constructor.

    Returns:
        dict with keys ``'extractor'`` and ``'matcher'``.

    Raises:
        ValueError: if ``extractor_type`` is not one of the supported names
            (comparison is case-sensitive).
    """
    supported_names = ('superpoint', 'disk', 'aliked', 'sift')
    # Validate before touching any model class so a bad name fails fast.
    if extractor_type not in supported_names:
        raise ValueError("Unsupported extractor type. Use 'superpoint', 'disk', 'aliked', or 'sift'.")

    extractor_cls = {
        'superpoint': SuperPoint,
        'disk': DISK,
        'aliked': ALIKED,
        'sift': SIFT,
    }[extractor_type]

    extractor = extractor_cls(max_num_keypoints=max_num_keypoints).eval().cuda()
    matcher = LightGlue(features=extractor_type).eval().cuda()
    return {'extractor': extractor, 'matcher': matcher}


def calculate_matching_loss(gt, render, extractor, matcher, max_num_keypoints=64, lambda_count=0.1):
    """Return the confidence-weighted keypoint displacement between two images.

    Keypoints are extracted from ``gt`` and ``render`` with ``extractor`` and
    matched with ``matcher``. The loss is the sum over all matches of
    ``score * ||p_gt - p_render||_2``, where ``score`` is the matcher's
    per-match confidence.

    Args:
        gt: ground-truth image, passed unchanged to ``extractor.extract``.
            NOTE(review): presumably a (C, H, W) tensor on the extractor's
            device -- confirm against callers.
        render: rendered image, same conventions as ``gt``.
        extractor: object exposing ``.extract(image) -> dict`` with a
            ``'keypoints'`` entry, e.g. ``init_extractor(...)['extractor']``.
        matcher: callable taking ``{'image0': feats, 'image1': feats}`` and
            returning ``'matches'`` / ``'scores'``, e.g.
            ``init_extractor(...)['matcher']``.
        max_num_keypoints: unused. The keypoint cap is fixed when the
            extractor is constructed (see ``init_extractor``). Kept for
            backward compatibility.
        lambda_count: unused, kept for backward compatibility. A count term
            built from the matched point sets would always be zero, because
            both sets are indexed by the same match list.

    Returns:
        Scalar tensor on the device of the matcher outputs (CUDA when the
        models come from ``init_extractor``). It is zero when there are no
        matches.

    NOTE(review): upstream LightGlue decorates ``Extractor.extract`` with
    ``torch.no_grad()``. If this fork does the same, the loss carries no
    gradient w.r.t. ``render``. Confirm before using it as a training signal.

    Example::

        models = init_extractor('superpoint', max_num_keypoints=2048)
        loss = calculate_matching_loss(gt, render,
                                       models['extractor'], models['matcher'])
    """
    feats_gt = extractor.extract(gt)
    feats_render = extractor.extract(render)

    matches01 = matcher({'image0': feats_gt, 'image1': feats_render})
    # rbd presumably strips the leading batch dimension -- see lightglue.utils.
    feats_gt, feats_render, matches01 = [rbd(x) for x in (feats_gt, feats_render, matches01)]
    matches = matches01['matches']  # column 0 indexes gt keypoints, column 1 render keypoints

    points_gt = feats_gt['keypoints'][matches[..., 0]]
    points_render = feats_render['keypoints'][matches[..., 1]]

    # Euclidean displacement per matched pair, weighted by match confidence.
    # Everything already lives on the matcher's device, so no .cuda() copies.
    dist = torch.norm(points_gt - points_render, dim=1)
    return torch.sum(matches01['scores'] * dist)

def _draw_marker(image, x, y, color):
    """Paint a small square marker ending at pixel (x, y) into ``image`` in place, clipped to bounds."""
    # Clamp the lower edge: a negative start would wrap to the end of the
    # axis and produce an empty slice, hiding markers on the top/left border.
    y_lo, x_lo = max(y - 1, 0), max(x - 1, 0)
    image[:, y_lo:y + 1, x_lo:x + 1] = color


def visualize_matches(gt, render, extractor_type='superpoint', max_num_keypoints=2048, output_path="matched_images.png"):
    """Save ``gt`` and ``render`` side by side with matched keypoints marked.

    Keypoints in ``gt`` (left half) are drawn in red; their matches in
    ``render`` (right half) are drawn in green.

    Args:
        gt: image tensor of shape (C, H, W). NOTE(review): the marker colours
            are 3-channel, so C must be 3.
        render: image tensor with the same C and H as ``gt`` (the two are
            concatenated along the width axis), on the same device.
        extractor_type: extractor/matcher pair to build; see ``init_extractor``.
        max_num_keypoints: keypoint cap for the extractor.
        output_path: destination file passed to ``torchvision.utils.save_image``.

    Raises:
        ValueError: if ``extractor_type`` is unsupported (from ``init_extractor``).
    """
    models = init_extractor(extractor_type, max_num_keypoints)
    extractor, matcher = models['extractor'], models['matcher']

    feats0 = extractor.extract(gt)
    feats1 = extractor.extract(render)
    matches01 = matcher({'image0': feats0, 'image1': feats1})
    feats0, feats1, matches01 = [rbd(x) for x in (feats0, feats1, matches01)]
    matches = matches01['matches']

    points0 = feats0['keypoints'][matches[..., 0]]
    points1 = feats1['keypoints'][matches[..., 1]]

    _, _, width = gt.shape
    # torch.cat allocates a new tensor, so drawing on it never touches the inputs.
    canvas = torch.cat((gt, render), dim=2)

    red = torch.tensor([1.0, 0.0, 0.0], device=canvas.device, dtype=canvas.dtype).view(3, 1, 1)
    green = torch.tensor([0.0, 1.0, 0.0], device=canvas.device, dtype=canvas.dtype).view(3, 1, 1)

    # A single .tolist() copies all coordinates to the host at once instead of
    # forcing one device sync per int() call.
    for (x0, y0), (x1, y1) in zip(points0.tolist(), points1.tolist()):
        _draw_marker(canvas, int(x0), int(y0), red)
        # Render keypoints sit in the right half of the canvas.
        _draw_marker(canvas, int(x1 + width), int(y1), green)

    save_image(canvas, output_path)
    
    
    

    
    
    
